#!/usr/bin/env python3

# This file is based on clickhouse-test under the Apache license here:
# https://github.com/ClickHouse/ClickHouse/blob/master/tests/clickhouse-test

import sys
import os
import os.path
import re
import traceback
import copy

from argparse import ArgumentParser
import shlex
import subprocess
from subprocess import Popen
from subprocess import PIPE
from datetime import datetime
from time import time, sleep
from errno import ESRCH

try:
    import termcolor
except ImportError:
    termcolor = None
import multiprocessing
from contextlib import closing

# Substrings that, when found in a failed test's stderr text, make
# need_retry() report the run as worth retrying. Empty here, so no retry
# is triggered unless entries are added.
MESSAGES_TO_RETRY = []


def remove_control_characters(s):
    """Decode numeric character references, then strip control characters.

    Decimal (``&#NN;``) and hexadecimal (``&#xNN;``) references whose code
    point is below 0x10000 are replaced by the character they denote; larger
    ones are left untouched. Afterwards every character in
    ``\\x00-\\x08``, ``\\x0b``, ``\\x0e-\\x1f`` and ``\\x7f`` is removed.

    Approach taken from
    https://github.com/html5lib/html5lib-python/issues/96#issuecomment-43438438
    """

    def decode_reference(match, radix):
        # Keep the reference verbatim when it points outside the BMP.
        code_point = int(match.group(1), radix)
        if code_point < 0x10000:
            return chr(code_point)
        return match.group(0)

    decimal_reference = r"&#(\d+);?"
    hex_reference = r"&#[xX]([0-9a-fA-F]+);?"

    decoded = re.sub(decimal_reference, lambda m: decode_reference(m, 10), s)
    decoded = re.sub(hex_reference, lambda m: decode_reference(m, 16), decoded)
    return re.sub(r"[\x00-\x08\x0b\x0e-\x1f\x7f]", "", decoded)


def _read_output_text(path):
    """Return the content of *path* decoded as UTF-8 (lossy), or "" if absent."""
    if not os.path.exists(path):
        return ""
    with open(path, "rb") as f:
        return str(f.read(), errors="replace", encoding="utf-8")


def run_single_test(
    args,
    ext,
    client_options,
    case_file_full_path,
    stdout_file,
    stderr_file,
    result_file,
):
    """Run one test case through the shell and collect its output.

    ``.sql`` cases are stripped of ``--`` comment lines with ``sed`` and piped
    into ``args.client_with_database``; any other case file is executed
    directly. In both cases stdout and stderr of the command are redirected
    into ``stdout_file``.

    Args:
        args: parsed options; ``client_with_database``, ``timeout``,
            ``database`` and ``record`` are read.
        ext: extension of the case file (``".sql"`` selects the client path).
        client_options: extra option string appended to the client command.
        case_file_full_path: path of the test case file.
        stdout_file: file receiving the combined output of the command.
        stderr_file: file read back as stderr if it exists; the command built
            here does not write to it.
        result_file: reference result path; its ``.result_filter`` sibling is
            looked up, and it is overwritten when ``args.record == 1``.

    Returns:
        ``(proc, stdout, stderr, total_time)``. ``proc.returncode`` is ``None``
        when the command was still running after ``args.timeout`` seconds;
        the process is not killed here.
    """
    # Paths are quoted so that directories containing spaces or shell
    # metacharacters do not break the command line. The client and its
    # options are command fragments and must stay unquoted.
    quoted_test = shlex.quote(case_file_full_path)
    quoted_stdout = shlex.quote(stdout_file)

    if ext == ".sql":
        command = (
            "sed '/^\\s*--/d' {test} | {client} {options} > {stdout} 2>&1".format(
                test=quoted_test,
                client=args.client_with_database,
                options=client_options,
                stdout=quoted_stdout,
            )
        )
    else:
        command = "{test} > {stdout} 2>&1".format(
            test=quoted_test, stdout=quoted_stdout
        )

    proc = Popen(command, shell=True, env=os.environ)
    start_time = datetime.now()
    # Poll until the process exits or the per-test timeout elapses.
    while (
        datetime.now() - start_time
    ).total_seconds() < args.timeout and proc.poll() is None:
        sleep(0.01)

    total_time = (datetime.now() - start_time).total_seconds()

    # Normalize randomized database names in the captured output.
    os.system(
        "LC_ALL=C sed -i -e {expr} {file}".format(
            expr=shlex.quote("s/{}/default/g".format(args.database)),
            file=quoted_stdout,
        )
    )

    if proc.returncode == 0:
        # Strip only a trailing ".result"; splitting on every occurrence would
        # truncate paths that contain ".result" in a directory name.
        suffix = ".result"
        if result_file.endswith(suffix):
            filter_base = result_file[: -len(suffix)]
        else:
            filter_base = result_file
        filter_fn = filter_base + ".result_filter"

        if os.path.exists(filter_fn):
            # The filter file holds alternating pattern / replacement lines.
            with open(filter_fn, "rb") as f:
                filter_lines = [line.strip() for line in f]
            # zip() drops a dangling pattern that has no replacement line
            # instead of failing with IndexError.
            substitutions = list(zip(filter_lines[0::2], filter_lines[1::2]))

            with open(stdout_file, "rb") as f:
                stdout_lines = f.readlines()

            filtered = []
            for line in stdout_lines:
                line = line.strip()
                for pattern, replacement in substitutions:
                    line = re.sub(pattern, replacement, line)
                filtered.append(line)

            with open(stdout_file, "wb") as f:
                f.write(b"\n".join(filtered) + b"\n")

    stdout = _read_output_text(stdout_file)
    stderr = _read_output_text(stderr_file)

    if proc.returncode == 0 and args.record == 1:
        # Record mode: the fresh output becomes the new reference result.
        os.system(
            "LC_ALL=C cp {stdout} {result}".format(
                stdout=quoted_stdout, result=shlex.quote(result_file)
            )
        )

    return proc, stdout, stderr, total_time


def need_retry(stderr):
    """Tell whether *stderr* contains any of the MESSAGES_TO_RETRY substrings."""
    for retry_message in MESSAGES_TO_RETRY:
        if retry_message in stderr:
            return True
    return False


def get_processlist(client_cmd):
    """Run ``SHOW PROCESSLIST FORMAT Vertical`` through *client_cmd*.

    Returns the raw output of the client (bytes) on success, or an empty
    string when the command cannot be run or exits with a non-zero status.
    """
    try:
        query_cmd = f"{client_cmd} --query 'SHOW PROCESSLIST FORMAT Vertical'"
        processlist = subprocess.check_output(query_cmd, shell=True)
    except Exception:
        # server seems dead
        return ""
    return processlist


def get_stacktraces(server_pid):
    """Ask gdb for the backtraces of all threads of process *server_pid*.

    Returns gdb's raw output (bytes) on success; if the command fails, returns
    a string starting with "Error occurred while receiving stack traces"
    followed by the error text.
    """
    gdb_cmd = f"gdb -batch -ex 'thread apply all backtrace' -p {server_pid}"
    try:
        backtraces = subprocess.check_output(gdb_cmd, shell=True)
    except Exception as ex:
        return f"Error occurred while receiving stack traces {str(ex)}"
    return backtraces


def get_server_pid(server_tcp_port):
    """Return the PID of the process listening on *server_tcp_port*.

    ``lsof -Fp`` prints process IDs as ``p<pid>`` lines; the ``awk`` filter
    keeps those lines and strips the leading ``p``.

    Returns:
        The PID as an int, or ``None`` when nothing is listening, when the
        output cannot be parsed, or when the command fails.
    """
    # The options must be written without embedded spaces ("-i tcp:PORT",
    # not "- i tcp: PORT"), otherwise lsof rejects the command line.
    cmd = (
        "lsof -i tcp:{port} -s tcp:LISTEN -Fp | "
        "awk '/^p[0-9]+$/{{print substr($0, 2)}}'"
    ).format(port=server_tcp_port)
    try:
        output = subprocess.check_output(cmd, shell=True)
        # awk already removed the "p" prefix, so the first token is the whole
        # PID; slicing it again would drop the first digit.
        pids = output.split()
        if pids:
            return int(pids[0])
        return None  # server dead
    except Exception:
        return None


def colored(text, args, color=None, on_color=None, attrs=None):
    """Return *text* wrapped in terminal colors when coloring is enabled.

    Coloring is applied only if termcolor was imported and either stdout is a
    terminal or ``args.force_color`` is set; otherwise *text* is returned
    unchanged.
    """
    use_color = termcolor and (sys.stdout.isatty() or args.force_color)
    if not use_color:
        return text
    return termcolor.colored(text, color, on_color, attrs)


# Set to True by run_tests_array() when a failed test's stderr suggests the
# server is gone; both run_tests_array() and main() stop iterating once set.
SERVER_DIED = False
# Status passed to sys.exit() at the end of main(); run_tests_array() sets it
# to 1 when a batch finishes with failures.
exit_code = 0
# Optional deadline compared against time() in run_tests_array(); nothing in
# the visible code assigns it, so the check is inactive by default.
stop_time = None

# main() builds a list of entries shaped as shown below and passes each
# entry to run_tests_array() as `all_tests_with_params`.
# Every test_case is the full path of a test case file.
# [
#  [
#   [
#       [test_case0, suite_dir0, suite_tmp_dir0],
#       [test_case1, suite_dir1, suite_tmp_dir1],
#       ...
#   ],
#   suite,
#   run_total
#  ],
#   ...
# ]


def run_tests_array(all_tests_with_params):
    """Run one batch of test cases, printing a status line per test and a summary.

    Args:
        all_tests_with_params: a three-item sequence
            ``(test_list, suite, run_total)``. Each ``test_list`` entry is
            ``[case_file_full_path, suite_dir, suite_tmp_dir]``; ``suite`` is
            the name shown in the banner; ``run_total`` is unpacked but not
            used further in this function.

    Reads the module-level ``args``. Sets the module-level ``exit_code`` to 1
    when at least one test failed, and ``SERVER_DIED`` to True when a failure
    looks like a dead server. A test passes when the command exits with 0,
    produces no stderr, prints no "Exception" and its output equals the
    reference ``.result`` file.
    """
    global exit_code
    global SERVER_DIED
    global stop_time

    all_test_list, suite, run_total = all_tests_with_params

    OP_SQUARE_BRACKET = colored("[", args, attrs=["bold"])
    CL_SQUARE_BRACKET = colored("]", args, attrs=["bold"])

    MSG_FAIL = (
        OP_SQUARE_BRACKET
        + colored(" FAIL ", args, "red", attrs=["bold"])
        + CL_SQUARE_BRACKET
    )
    MSG_UNKNOWN = (
        OP_SQUARE_BRACKET
        + colored(" UNKNOWN ", args, "yellow", attrs=["bold"])
        + CL_SQUARE_BRACKET
    )
    MSG_OK = (
        OP_SQUARE_BRACKET
        + colored(" OK ", args, "green", attrs=["bold"])
        + CL_SQUARE_BRACKET
    )
    MSG_SKIPPED = (
        OP_SQUARE_BRACKET
        + colored(" SKIPPED ", args, "cyan", attrs=["bold"])
        + CL_SQUARE_BRACKET
    )

    passed_total = 0
    skipped_total = 0
    failures_total = 0
    failures = 0
    # Number of consecutive failures; reset by every passing test.
    failures_chain = 0
    failure_cases = []

    client_options = get_additional_client_options(args)

    if len(all_test_list):
        print("\nRunning {} {} tests.".format(len(all_test_list), suite) + "\n")

    def print_test_time(test_time):
        if args.print_time:
            return " {0:.2f} sec.".format(test_time)
        else:
            return ""

    for test_list in all_test_list:
        failures = 0
        if SERVER_DIED:
            break

        if stop_time and time() > stop_time:
            print("\nStop tests run because global time limit is exceeded.\n")
            break
        case = test_list[0].split("/")[-1]
        (name, ext) = os.path.splitext(case)

        try:
            status = ""
            is_concurrent = multiprocessing.current_process().name != "MainProcess"
            if not is_concurrent:
                sys.stdout.flush()
                sys.stdout.write("{0:72}".format(name + ": "))
                # This flush is needed so you can see the test name of the long
                # running test before it will finish. But don't do it in
                # parallel mode, so that the lines don't mix.
                sys.stdout.flush()
            else:
                status = "{0:72}".format(name + ": ")

            if args.skip and any(re.search(r, case) for r in args.skip):
                status += MSG_SKIPPED + " - skip\n"
                skipped_total += 1
            else:
                case_file_full_path = test_list[0]
                suite_dir = test_list[1]
                suite_tmp_dir = test_list[2]
                disabled_file = os.path.join(suite_dir, name) + ".disabled"

                if os.path.exists(disabled_file) and not args.disabled:
                    # The content of the .disabled file is shown as the reason.
                    with open(disabled_file, "r") as f:
                        message = f.read()
                    status += MSG_SKIPPED + " - " + message + "\n"
                    # Reported as SKIPPED, so it belongs in the skipped count.
                    skipped_total += 1
                else:
                    # Keep output files of parallel repeated runs apart.
                    file_suffix = (
                        ("." + str(os.getpid()))
                        if is_concurrent and args.test_runs > 1
                        else ""
                    )
                    result_file = os.path.join(suite_dir, name) + ".result"
                    cluster_result_file = (
                        os.path.join(suite_dir, name) + "_cluster.result"
                    )
                    stdout_file = (
                        os.path.join(suite_tmp_dir, name) + file_suffix + ".stdout"
                    )
                    stderr_file = (
                        os.path.join(suite_tmp_dir, name) + file_suffix + ".stderr"
                    )

                    # In cluster mode a dedicated reference result wins if present.
                    if args.mode == "cluster" and os.path.isfile(cluster_result_file):
                        result_file = cluster_result_file
                        stdout_file = (
                            os.path.join(suite_tmp_dir, name)
                            + file_suffix
                            + "_cluster.stdout"
                        )
                        stderr_file = (
                            os.path.join(suite_tmp_dir, name)
                            + file_suffix
                            + "_cluster.stderr"
                        )

                    proc, stdout, stderr, total_time = run_single_test(
                        args,
                        ext,
                        client_options,
                        case_file_full_path,
                        stdout_file,
                        stderr_file,
                        result_file,
                    )

                    if proc.returncode is None:
                        # Still running after the timeout: kill it. ESRCH means
                        # the process exited in the meantime, which is fine.
                        try:
                            proc.kill()
                        except OSError as e:
                            if e.errno != ESRCH:
                                raise

                        failures += 1
                        failure_cases.append(case_file_full_path)
                        status += MSG_FAIL
                        status += print_test_time(total_time)
                        status += " - Timeout!\n"
                        if stderr:
                            status += stderr
                    else:
                        # Re-run failures whose stderr matches a retryable
                        # message, backing off exponentially, at most 6 times.
                        counter = 1
                        while proc.returncode != 0 and need_retry(stderr):
                            proc, stdout, stderr, total_time = run_single_test(
                                args,
                                ext,
                                client_options,
                                case_file_full_path,
                                stdout_file,
                                stderr_file,
                                result_file,
                            )
                            sleep(2**counter)
                            counter += 1
                            if counter > 6:
                                break

                        if proc.returncode != 0:
                            failures += 1
                            failure_cases.append(case_file_full_path)
                            failures_chain += 1
                            status += MSG_FAIL
                            status += print_test_time(total_time)
                            status += " - return code {}\n".format(proc.returncode)

                            if stderr:
                                status += stderr

                            # Stop on fatal errors like segmentation fault.
                            #  They are sent to client via logs.
                            if " <Fatal> " in stderr:
                                SERVER_DIED = True

                            if (
                                args.stop
                                and (
                                    "Connection refused" in stderr
                                    or "Attempt to read after eof" in stderr
                                )
                                and "Received exception from server" not in stderr
                            ):
                                SERVER_DIED = True

                            if os.path.isfile(stdout_file):
                                status += ", result:\n\n"
                                # Show at most the first 300 lines; undecodable
                                # bytes must not turn into an internal error.
                                with open(stdout_file, errors="replace") as f:
                                    status += "\n".join(f.read().split("\n")[:300])
                                status += "\n"

                        elif stderr:
                            failures += 1
                            failure_cases.append(case_file_full_path)
                            failures_chain += 1
                            status += MSG_FAIL
                            status += print_test_time(total_time)
                            status += " - having stderror:\n{}\n".format(
                                "\n".join(stderr.split("\n")[:300])
                            )
                        elif "Exception" in stdout:
                            failures += 1
                            failure_cases.append(case_file_full_path)
                            failures_chain += 1
                            status += MSG_FAIL
                            status += print_test_time(total_time)
                            status += " - having exception:\n{}\n".format(
                                "\n".join(stdout.split("\n")[:300])
                            )
                        elif not os.path.isfile(result_file):
                            status += MSG_UNKNOWN
                            status += print_test_time(total_time)
                            status += " - no result file\n"
                        else:
                            # Only the exit status is needed, so discard the
                            # output instead of leaving an unread pipe.
                            result_is_different = subprocess.call(
                                ["diff", "-q", result_file, stdout_file],
                                stdout=subprocess.DEVNULL,
                            )
                            if not args.complete and result_is_different:
                                diff = Popen(
                                    [
                                        "diff",
                                        "-U",
                                        str(args.unified),
                                        result_file,
                                        stdout_file,
                                    ],
                                    stdout=PIPE,
                                    universal_newlines=True,
                                ).communicate()[0]
                                failures += 1
                                failure_cases.append(case_file_full_path)
                                status += MSG_FAIL
                                status += print_test_time(total_time)
                                status += " - result differs with:\n{}\n".format(diff)
                            else:
                                if args.complete:
                                    # --complete: accept the new output as the
                                    # reference result.
                                    Popen(
                                        ["cp", stdout_file, result_file],
                                        stdout=PIPE,
                                        universal_newlines=True,
                                    ).communicate()
                                passed_total += 1
                                failures_chain = 0
                                status += MSG_OK
                                status += print_test_time(total_time)
                                status += "\n"
                                if os.path.exists(stdout_file):
                                    os.remove(stdout_file)
                                if os.path.exists(stderr_file):
                                    os.remove(stderr_file)

            if status and not status.endswith("\n"):
                status += "\n"

            sys.stdout.write(status)
            sys.stdout.flush()
        except KeyboardInterrupt:
            print(colored("Break tests execution", args, "red"))
            raise
        except Exception:
            exc_type, exc_value, tb = sys.exc_info()
            failures += 1
            failure_cases.append("Test internal error")
            print(
                "{0} - Test internal error: {1}\n{2}\n{3}".format(
                    MSG_FAIL,
                    exc_type.__name__,
                    exc_value,
                    "\n".join(traceback.format_tb(tb, 10)),
                )
            )

        # Count this test before the chain check so that the failure that
        # triggers the early stop is included in the summary.
        failures_total = failures_total + failures

        if failures_chain >= 20:
            break

    if failures_total > 0:
        print(
            colored(
                (
                    "\nHaving {failures_total} errors! "
                    "{passed_total} tests passed. "
                    "{skipped_total} tests skipped."
                ).format(
                    passed_total=passed_total,
                    skipped_total=skipped_total,
                    failures_total=failures_total,
                ),
                args,
                "red",
                attrs=["bold"],
            )
        )
        print(colored("The failure tests:", args, "red", attrs=["bold"]))
        for case in failure_cases:
            print(colored("    {case}".format(case=case), args, "red", attrs=["bold"]))
        exit_code = 1
    else:
        print(
            colored(
                "\n{passed_total} tests passed. {skipped_total} tests skipped.".format(
                    passed_total=passed_total, skipped_total=skipped_total
                ),
                args,
                "green",
                attrs=["bold"],
            )
        )


def main(args):
    """Discover the test suites under ``args.suites`` and run them.

    Suites are the entries of the suites directory named ``<number>_<name>``,
    visited in ascending numeric order. Inside a suite, test files
    (``.sql``, ``.sh``, ``.py``) are collected recursively from the suite
    directory and from sub-directories whose name starts with a digit. The
    tests are sorted, split into ``run_total`` consecutive groups and handed
    to ``run_tests_array`` either in a process pool (``--jobs`` > 1) or for
    the single group selected by ``--parallel``.

    Exits the process via ``sys.exit``: with 1 when a test file name is
    illegal or no test was found, otherwise with the module-level
    ``exit_code``.
    """
    global SERVER_DIED
    global exit_code

    base_dir = os.path.abspath(args.suites)
    tmp_dir = os.path.abspath(args.tmp)

    os.environ.setdefault("QUERY_BINARY", args.binary)
    os.environ.setdefault("QUERY_DATABASE", args.database)

    # Send a trivial query through the client before running any suite.
    # The outcome is not inspected.
    databend_query_proc_create = Popen(
        shlex.split(args.client), stdin=PIPE, stdout=PIPE, stderr=PIPE
    )
    databend_query_proc_create.communicate(b"SELECT 1")

    def numeric_prefix_key(name):
        """Sort key for ``<number>_<rest>`` names: ``(number, rest)``.

        Names without an underscore or without a numeric prefix sort last.
        """
        if "_" not in name:
            return 99998, ""
        prefix, suffix = name.split("_", 1)
        try:
            return int(prefix), suffix
        except ValueError:
            return 99997, ""

    def list_full_paths(cur_dir_path):
        return [
            os.path.join(cur_dir_path, entry) for entry in os.listdir(cur_dir_path)
        ]

    def collect_subdirs_with_pattern(cur_dir_path, pattern):
        # Keep only directories whose own name matches the pattern.
        return [
            fullpath
            for fullpath in list_full_paths(cur_dir_path)
            if os.path.isdir(fullpath) and re.search(pattern, fullpath.split("/")[-1])
        ]

    def collect_files_with_pattern(cur_dir_path, patterns):
        # `patterns` is a "|"-separated list of file extensions.
        extensions = patterns.split("|")
        return [
            fullpath
            for fullpath in list_full_paths(cur_dir_path)
            if os.path.isfile(fullpath) and os.path.splitext(fullpath)[1] in extensions
        ]

    def get_all_tests_under_dir_recursive(suite_dir):
        # Breadth-first walk: files of a directory first, then its
        # digit-prefixed sub-directories are queued.
        all_tests = []
        pending_dirs = [suite_dir]
        while pending_dirs:
            cur_dir_path = pending_dirs.pop(0)
            all_tests += collect_files_with_pattern(cur_dir_path, ".sql|.sh|.py")
            pending_dirs += collect_subdirs_with_pattern(cur_dir_path, "^[0-9]+")
        return all_tests

    total_tests_run = 0
    for suite in sorted(os.listdir(base_dir), key=numeric_prefix_key):
        if SERVER_DIED:
            break

        if args.skip_dir and any(s in suite for s in args.skip_dir):
            continue

        if args.run_dir and not any(s in suite for s in args.run_dir):
            continue

        # suite_dir is the top directory of the suite; tests living in a
        # sub-directory get that sub-directory as their own suite dir below,
        # because the result files are read from next to the test file.
        suite_dir = os.path.join(base_dir, suite)
        suite_re_obj = re.search("^[0-9]+_(.*)$", suite)
        if not suite_re_obj:  # skip .gitignore and so on
            continue

        suite_tmp_dir = os.path.join(tmp_dir, suite)
        if not os.path.exists(suite_tmp_dir):
            os.makedirs(suite_tmp_dir)

        suite_pat = suite_re_obj.group(1)
        if os.path.isdir(suite_dir):
            all_tests = get_all_tests_under_dir_recursive(suite_dir)

            if args.test:
                all_tests = [
                    t
                    for t in all_tests
                    if any(re.search(r, t.split("/")[-1]) for r in args.test)
                ]

            bad_tests = [
                case
                for case in all_tests
                if re.search("^[0-9]+_[0-9]+_(.*)$", case.split("/")[-1]) is None
            ]

            if len(bad_tests) > 0:
                print(
                    "Illegal test case names: {}, "
                    "must match `^[0-9]+_[0-9]+_(.*)$`".format(bad_tests)
                )
                sys.exit(1)

            # Ascending by the numeric prefix of the file name, then by the
            # rest of the name.
            all_tests.sort(key=lambda path: numeric_prefix_key(path.split("/")[-1]))
            run_n, run_total = args.parallel.split("/")
            run_n = float(run_n)
            run_total = float(run_total)
            tests_n = len(all_tests)
            if run_total > tests_n:
                run_total = tests_n
            if run_n > run_total:
                continue

            jobs = args.jobs
            if jobs > tests_n:
                jobs = tests_n
            if jobs > run_total:
                run_total = jobs

            suite_depth = len(suite_dir.split("/"))
            all_tests_array = []

            for n in range(1, 1 + int(run_total)):
                # Multiply before dividing: the boundary of the last group is
                # then exactly tests_n, so float rounding cannot drop a test.
                group_start = int(tests_n * (n - 1) / run_total)
                group_end = int(tests_n * n / run_total)

                test_case_units = []
                for test_path in all_tests[group_start:group_end]:
                    # Path components between the suite dir and the file name;
                    # empty when the test sits directly in the suite dir.
                    sub_dir_parts = test_path.split("/")[suite_depth:-1]
                    case_suite_dir = os.path.join(suite_dir, *sub_dir_parts)
                    case_tmp_dir = os.path.join(suite_tmp_dir, *sub_dir_parts)
                    if sub_dir_parts:
                        # Output files are written here; the directory does
                        # not exist yet when --tmp differs from --suites.
                        os.makedirs(case_tmp_dir, exist_ok=True)
                    test_case_units.append([test_path, case_suite_dir, case_tmp_dir])

                all_tests_array.append([test_case_units, suite_pat, run_total])

            if jobs > 1:
                # NOTE(review): run_tests_array() records failures in the
                # module-level exit_code / SERVER_DIED of the worker process;
                # those assignments are not visible in this parent process.
                with closing(multiprocessing.Pool(processes=jobs)) as pool:
                    pool.map(run_tests_array, all_tests_array)
                    pool.terminate()
            else:
                run_tests_array(all_tests_array[int(run_n) - 1])

            total_tests_run += tests_n

    if total_tests_run == 0:
        print("No tests were run.")
        sys.exit(1)

    sys.exit(exit_code)


def get_additional_client_options(args):
    """Return the extra client option string, taken verbatim from ``args.options``."""
    return args.options


def get_additional_client_options_url(args):
    """Return an empty string; ``args`` is accepted but not used."""
    return ""


if __name__ == "__main__":
    # Command-line entry point: parse the options, build the client command
    # line (user, host, port, database) and hand over to main().
    parser = ArgumentParser(description="databend-query functional tests")
    parser.add_argument("-q", "--suites", help="Path to suites dir")
    parser.add_argument(
        "-b",
        "--binary",
        default="databend-query",
        help="Path to databend-query binary or name of binary in PATH",
    )
    parser.add_argument("-c", "--client", default="mysql", help="Client program")
    parser.add_argument(
        "-opt",
        "--options",
        default=" --comments --force ",
        help="Client program options",
    )
    parser.add_argument("--tmp", help="Path to tmp dir")
    parser.add_argument(
        "-t",
        "--timeout",
        type=int,
        default=900,
        help="Timeout for each test case in seconds",
    )
    parser.add_argument(
        "--record",
        type=int,
        default=0,
        help="Force override result files from stdout files",
    )
    parser.add_argument("test", nargs="*", help="Optional test case name regex")
    parser.add_argument(
        "--test-runs",
        default=1,
        nargs="?",
        type=int,
        help="Run each test many times (useful for e.g. flaky check)",
    )
    parser.add_argument(
        "-d",
        "--disabled",
        action="store_true",
        default=False,
        help="Also run disabled tests",
    )
    parser.add_argument("--force-color", action="store_true", default=False)
    parser.add_argument(
        "--print-time", action="store_true", dest="print_time", help="Print test time"
    )
    parser.add_argument(
        "-U",
        "--unified",
        default=3,
        type=int,
        help="output NUM lines of unified context",
    )
    parser.add_argument(
        "--database",
        default="default",
        help="Database for tests (a random test_XXXXXX name is generated "
        "when an empty value is given)",
    )
    parser.add_argument(
        "--parallel", default="1/1", help="One parallel test run number/total"
    )
    parser.add_argument(
        "-j", "--jobs", default=1, nargs="?", type=int, help="Run all tests in parallel"
    )
    parser.add_argument("--skip", nargs="+", help="Skip these tests")
    parser.add_argument("--skip-dir", nargs="+", help="Skip all these tests in the dir")
    parser.add_argument("--run-dir", nargs="+", help="Only run these tests in the dir")
    parser.add_argument(
        "--complete", action="store_true", default=False, help="complete results"
    )
    parser.add_argument(
        "--stop",
        action="store_true",
        default=None,
        dest="stop",
        help="Stop on network errors",
    )
    parser.add_argument(
        "--mode",
        default="standalone",
        help="DatabendQuery running mode, "
        "the value can be 'standalone' or 'cluster'",
    )

    args = parser.parse_args()
    if os.getenv("DATABEND_DEV_CONTAINER") is not None:
        # Run the client unbuffered inside the dev container.
        args.client = "stdbuf -i0 -o0 -e0 {}".format(args.client)
    if args.suites is None and os.path.isdir("suites"):
        args.suites = "suites"
    if args.suites is None:
        print(
            "Failed to detect path to the suites directory. "
            "Please specify it with '--suites' option."
        )
        # sys.exit rather than the interactive-only `exit` helper.
        sys.exit(1)

    if args.tmp is None:
        args.tmp = args.suites

    args.client += " --user default -s"

    # Host and port come from the environment when set, else from defaults.
    args.tcp_host = os.getenv("QUERY_MYSQL_HANDLER_HOST", "127.0.0.1")
    args.client += f" --host={args.tcp_host}"

    tcp_port = os.getenv("QUERY_MYSQL_HANDLER_PORT", "3307")
    args.tcp_port = int(tcp_port)
    args.client += f" --port={tcp_port}"

    args.client_with_database = args.client
    if not args.database:

        def random_str(length=6):
            import random
            import string

            alphabet = string.ascii_lowercase + string.digits
            return "".join(random.choice(alphabet) for _ in range(length))

        args.database = "test_{suffix}".format(suffix=random_str())
    # The database name is passed to the client as a positional argument.
    args.client_with_database += " " + args.database
    main(args)
